# -*- coding: utf-8 -*-
"""
Created on Fri Jun 17 10:53:57 2022

@author: LSY
"""

import numpy as np
from numpy import random
import timeit
import mpctools as mpc
import scipy.io as spio # load .mat file

np.set_printoptions(precision=4)
np.set_printoptions(linewidth=400)

# ---------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------
Nx = 30  # Number of differential states
Ny = 30  # Number of outputs
Nw = Nx  # Number of process-noise channels (one per state)
Nv = Ny  # Number of measurement-noise channels (one per output)

sigma_w = 0.01  # standard deviation of the process noise
sigma_v = 0.01  # standard deviation of the measurement noise
# Gaussian noise realisations. NOTE(review): neither w nor v is referenced
# elsewhere in this file; they are kept for experimentation.
w = sigma_w*np.random.randn(Nx,Nw)
# Measurement noise lives in output space, so it has Ny rows (it was drawn
# with Nx rows, which only had the right shape because Nx == Ny).
v = sigma_v*np.random.randn(Ny,Nv)

#### sensor set(measurements): assume that all states have measurements
def measure(x):
    """Map a state vector ``x`` to outputs through an all-ones sensor matrix.

    The sensor matrix is built from the module-level ``Ny`` and ``Nx`` as a
    ``(Ny, Nx)`` array of ones, so every output equals the sum of all states.
    NOTE(review): the comment above says every state is measured, which would
    usually be an identity matrix -- confirm that the all-ones choice is
    intended. This function is not called elsewhere in this file.
    """
    return np.dot(np.ones((Ny, Nx)), x)

# ---------------------------------------------------------------------
# Model setup
# ---------------------------------------------------------------------
# Load the model matrices from MATLAB .mat files (paths are relative to the
# current working directory). squeeze_me=True drops unit matrix dimensions.
mat_A = spio.loadmat('A.mat', squeeze_me=True)
A = mat_A['A']  # propagation matrix passed to SenMatrix
mat_C = spio.loadmat('C.mat', squeeze_me=True)
C = mat_C['C']  # output matrix passed to SenMatrix
mat_E = spio.loadmat('E.mat', squeeze_me=True)
E = mat_E['E']  # NOTE(review): E and F are not referenced elsewhere in this file
mat_F = spio.loadmat('F.mat', squeeze_me=True)
F = mat_F['F']

print("Calculate the sensitive matrix with all measurements ......")
# Start timing the setup phase; the elapsed time is printed after the price
# data has been loaded.
Tstart = timeit.default_timer()


def SenMatrix(A, C):
    """Stack ``C, C A, C A^2, ..., C A^(n-1)`` into one tall matrix.

    Parameters
    ----------
    A : 2-D array
        Propagation matrix; ``n = A.shape[0]`` sets both the number of
        stacked blocks and the number of columns of the result.
    C : 2-D array
        Output matrix; ``p = C.shape[0]`` is the number of rows per block.

    Returns
    -------
    ndarray of float64, shape ``(n*p, n)``
        Rows ``k*p`` to ``(k+1)*p`` hold ``C`` multiplied by ``A`` ``k`` times.
    """
    n_states = A.shape[0]
    n_out = C.shape[0]
    stacked = np.zeros((n_states * n_out, n_states))
    stacked[:n_out, :] = C
    for blk in range(1, n_states):
        lo = blk * n_out
        # Each block is the previous block propagated one more step by A.
        stacked[lo:lo + n_out, :] = np.dot(stacked[lo - n_out:lo, :], A)
    return stacked


# Sensitivity matrix with all measurements: [C; CA; ...; CA^(n-1)] with
# n = A.shape[0], i.e. shape (A.shape[0]*C.shape[0], A.shape[0]).
sm_all = SenMatrix(A,C) 
# NOTE(review): this module-level rank is not read again in this file;
# SensAnal computes its own rank of whatever matrix it is given.
rank = np.linalg.matrix_rank(sm_all)

# Multiplier of the termination threshold; read as a global inside SensAnal.
alpha = 3
# Number of stacked time steps used when selecting rows of sm_all in the
# removal loop (set equal to Nx).
size = Nx
def SensAnal(Sall,Nx,Ny,sigma_w,sigma_v,size): # size is the value of window size in FIE and MHE
    """Greedy sensitivity analysis of a stacked sensitivity matrix.

    Repeatedly picks the column of the current residual matrix with the
    largest Euclidean norm and records the matching column of ``Sall`` in
    ``Xl``. It then recomputes the residual as ``Sall`` minus its projection
    onto the span of the selected columns, with the selected columns zeroed.

    Selection stops when any of these happens:

    - the largest residual column norm drops below
      ``alpha*sqrt(sigma_w**2 + sigma_v**2)``;
    - ``rank(Sall)`` steps have been taken;
    - the selected columns become linearly dependent.

    Parameters
    ----------
    Sall : ndarray
        Sensitivity matrix. It is copied into work arrays of shape
        ``((size+1)*Ny, Nx)``, so it is expected to have that shape.
    Nx : int
        Number of columns (states) of ``Sall``.
    Ny : int
        Number of measurement rows per stacked time step.
    sigma_w, sigma_v : float
        Noise levels combined into the termination threshold.
    size : int
        ``size+1`` stacked time steps are assumed in ``Sall``. The callers in
        this file pass the module-level ``size`` minus one.

    Returns
    -------
    ndarray, shape (2,)
        ``[rank(Sall), Js]``. ``Js`` is the sum over steps of the largest
        residual column norm. It includes the step at which the threshold
        test stopped the loop; steps never reached contribute 0. ``Js`` is
        forced to 0 when ``rank(Sall) < Nx``.

    Notes
    -----
    Reads the module-level global ``alpha``. Uses ``mpc.mtimes`` from
    mpctools, assumed here to be a plain matrix product (TODO confirm).
    Prints ``rank(Sall)`` as a side effect. The local ``rank`` is reassigned
    only to end the loop; the returned rank is always the initial one.
    """
    rank = np.linalg.matrix_rank(Sall) # RANK_pzc(Sall)
    print(np.linalg.matrix_rank(Sall))  # prints the same rank, computed a second time

    # Initial rank kept as a (1, 1) int array so it can be concatenated into
    # the result.
    rank_S = np.zeros([1,1],dtype=int)
    rank_S[0,0] = rank
    # Variable selection procedure
    # Rl[i]: residual matrix at step i.
    # Zl[i]: projection of Sall onto the columns selected up to step i.
    # SumCol[i, j]: norm of column j of Rl[i].
    # Flag[i, 0]: column index selected at step i.
    Rl = np.zeros([Nx+1,(size+1)*Ny,Nx])
    Zl = np.zeros([Nx,(size+1)*Ny,Nx])
    SumCol = np.zeros([Nx,Nx])
    Flag = np.zeros([Nx,1],dtype=int)

    Rl[0,:,:] = Sall

    # Xl collects, one column per step, the selected columns of Sall.
    Xl = np.zeros([(size+1)*Ny,Nx])

    # Termination threshold on the largest residual column norm.
    degalpha =  alpha*np.sqrt(sigma_w**2+sigma_v**2) # np.sqrt((size+1)*Ny) *
    for i in range(rank):
        # Euclidean norm of every column of the current residual Rl[i].
        for j in range(Nx):
            for k in range((size+1)*Ny):
                SumCol[i,j] = SumCol[i,j] + Rl[i,k,j]**2
            SumCol[i,j] = np.sqrt(SumCol[i,j])
        # Index of the first column attaining the largest norm.
        Flag[i,0] = 0
        for i1 in range(Nx):
            if SumCol[i,Flag[i,0]] < SumCol[i,i1]:
               Flag[i,0] = i1

        # A prescribed value for sensitivity to terminate
        if SumCol[i,Flag[i,0]]  < degalpha:
            rank = i
            break
        else:
            Xl[:,i] = Sall[:,Flag[i,0]]

        # All permitted variables selected or Xl singular: terminate the
        # selection. On the last permitted step, stop without computing
        # another projection.
        if i == rank-1:
            break
        else:
            # Xl^T Xl has full rank iff the selected columns are linearly
            # independent, i.e. iff its inverse exists.
            rank_xx = np.linalg.matrix_rank(mpc.mtimes(Xl[:,0:i+1].T,Xl[:,0:i+1]))
            if rank_xx == (i+1):
                # Zl[i] = X (X^T X)^-1 X^T Sall is the projection of Sall onto
                # the span of the selected columns; Rl[i+1] is what remains.
                lsy=np.linalg.inv(mpc.mtimes(Xl[:,0:i+1].T,Xl[:,0:i+1]))
                Zl[i,:,:] = mpc.mtimes(mpc.mtimes(mpc.mtimes(Xl[:,0:i+1],lsy),
                                            Xl[:,0:i+1].T),Sall)
                Rl[i+1,:,:] = Sall-Zl[i,:,:]
                # Eliminate residual non-zero values: force the columns that
                # were already selected to exactly zero. The projection only
                # makes them zero up to round-off.
                for i1 in range(i+1):
                    for i2 in range((size+1)*Ny):
                        Rl[i+1,i2,Flag[i1,0]] = 0
            else:
                rank = i+1
                break
    # Degree Js: sum of the per-step largest column norms. Rows of SumCol
    # that were never filled are zero and contribute nothing.
    maxcol = np.amax(SumCol,axis=1)
    Js = np.sum(maxcol)
#    Js = np.linalg.norm(maxcol)
    # If the matrix is rank-deficient (e.g. after removing a sensor), set the
    # degree to 0 so that this sensor set is never preferred.
    if rank_S < Nx:
        Js = Js*0
    Result = np.concatenate((rank_S,Js),axis = None)  #,Flag,rank_S
    return Result


# Sensor prices from MATLAB. NOTE(review): presumably one entry per
# measurement, since it is contracted with a (1, Ny) selection vector.
mat_price = spio.loadmat('Price.mat', squeeze_me=True)
price = mat_price['Price']

# Selection vector over the Ny measurements: later code zeroes the entries of
# removed sensors, so 1 = kept and 0 = removed. All ones = full sensor set.
z = np.ones((1,Ny))
cost = np.dot(z,price)  # total cost of the full sensor set

Tend = timeit.default_timer()
#print("The sensitive matrix sm = \n",sm_all)
print('Cost time:', str(Tend-Tstart),'s')
####################################################################


#### The sensitivity degree with all the measurements:
Tstart = timeit.default_timer()
# SensAnal returns [rank, degree]. size-1 is passed because SensAnal sizes its
# work arrays with (size+1)*Ny rows, which must equal the size*Ny rows of
# sm_all (assuming Nx and Ny agree with the shapes of A and C).
Flag = SensAnal(sm_all,Nx,Ny,sigma_w,sigma_v,size-1)   
Rank_Sen = Flag[0]  # rank of the full sensitivity matrix
print('Rank of Sensiti:', Rank_Sen)
Sen_deg = Flag[-1]  # sensitivity degree Js of the full sensor set
print("Sen degree:",Sen_deg)
Tend = timeit.default_timer()
print('Cost time:', str(Tend-Tstart),'s')
########################################################


############ If these measurements have been removed:
M_remain = list(range(Ny))  # indices of the measurements still in the set
# Measurements removed a priori. Keep this a list: it is appended to later.
R_known = []  
#R_known = [0,1,11,13,14] # example: the removed sensors
# The number of removed sensors. Derived from R_known (it used to be a
# hard-coded 0) so that Ny and Nr_max stay consistent with R_known.
N_known = len(R_known)

# Cost of the current sensor set: zero the entries of removed sensors.
z_curr = np.zeros((1,Ny))
z_curr[:] = z[:]
z_curr[:,R_known] = 0
cost = np.dot(z_curr,price)
# NOTE(review): this baseline CPI divides by cost, whereas the removal loop
# divides by cost**3 -- confirm which definition is intended.
CPI = Sen_deg/cost

print("These measurements have been removed priori:\n", R_known)
for i in R_known:
    M_remain.remove(i)
Ny = Ny - N_known
Nr_max = Nx - N_known  # maximum number of further removal steps

print("Remain these measurements:\n", M_remain)
print("cost:", cost)  
print("CPI:", CPI)

###### calculate the rank and degree
#tf = np.eye(Nsim*(23))
#T_list = []
#        
#for ni in range(Nsim):        # delete the y rows for each of the Nsim steps
#    for nj in R_known:
#        T_list.append(nj+ni*(23))
#tf = np.delete(tf,T_list,axis=0)       
#SM_I = np.dot(tf,sm_all)
#Flag = SensAnal(SM_I,Nx,Ny,sigma_w,sigma_v,Nsim-1,rank_xp)
#Rank_Sen = Flag[0]
#print('Rank of Sensiti:', Rank_Sen)
#Sen_deg = Flag[-1]
#print("Sen degree:",Sen_deg)

################################ begin remove #####################################################
print("------------------------------------------Begin Remove-------------------------------------")

# Greedy backward elimination. At each step, try removing every remaining
# measurement, score the reduced set by CPI = degree / cost**3, and
# permanently remove the best-scoring candidate. The loop stops when:
# - the best candidate set is rank-deficient;
# - its cost is within COST_BUDGET;
# - Nr_max steps have been taken.

# Measurement rows per stacked time step in sm_all (one per row of C). This
# replaces the previously hard-coded 30.
Ny_all = C.shape[0]
# Stop once the best candidate set costs no more than this.
COST_BUDGET = 62

Rank_Sen_final = np.zeros((1,1))
Sen_deg_final = np.zeros((1,1))
cost_final = np.zeros((1,1))
CPI_final = np.zeros((1,1))

# .item() extracts the single value explicitly. Assigning a size-1 array to a
# scalar slot is deprecated since NumPy 1.25.
Rank_Sen_final[0,0] = Rank_Sen
Sen_deg_final[0,0] = Sen_deg
cost_final[0,0] = np.asarray(cost).item()
CPI_final[0,0] = np.asarray(CPI).item()

# CPI of every candidate removal, indexed by [step, original sensor index].
remove_infor = np.zeros((Nr_max, Ny_all))

# One handle for the whole run. It was reopened for every candidate, and
# referencing it was a NameError if no candidate was ever tried.
with open("Results.txt", "a+") as REs:
    for k in range(Nr_max):
        print("-----------------------------------------------------------------------------------")
        print(k+N_known+1,"- th remove", "(remain:", len(M_remain),"measurements)")
        Rank_Sen_remain_max = np.zeros((1,1))
        M_remain_try = []
        M_remain_copy = []
        R_known_try = []
        Ny = Ny - 1  # measurements per time step after this removal
        Sen_deg_opt = np.zeros((1,1))
        cost_opt = np.zeros((1,1))
        CPI_max = np.zeros((1,1))

        for i in M_remain:
            Tstart = timeit.default_timer()
            R_known_try[:] = R_known
            R_known_try.append(i)
            M_remove_try = i

            print('Try to remove:', M_remove_try)
            print('Try to remove:', M_remove_try, file = REs)
            M_remain_try[:] = M_remain
            M_remain_try.remove(i)

            print('Remain these measurements', M_remain_try)
            print('Remain these measurements', M_remain_try, file = REs)

            # Drop the rows of every removed measurement in each of the
            # `size` stacked time steps. This selects the same rows as the
            # former eye()-delete-multiply construction, without building and
            # multiplying a (size*Ny_all)^2 identity matrix.
            T_list = [nj + ni*Ny_all for ni in range(size) for nj in R_known_try]
            SM_I = np.delete(sm_all, T_list, axis=0)

            Flag = SensAnal(SM_I,Nx,Ny,sigma_w,sigma_v,size-1)

            Rank_Sen_remain = Flag[0]
            print('Rank of Sensiti:', Rank_Sen_remain)
            print('Rank of Sensiti:', Rank_Sen_remain, file = REs)
            Sen_deg = Flag[-1]

            z_curr[:] = z[:]
            z_curr[:,R_known_try] = 0
            cost = np.dot(z_curr,price)
            CPI = Sen_deg/cost**3

            # Record the CPI for later plotting in Matlab.
            remove_infor[k,i] = np.asarray(CPI).item()

            print("Sen degree:",Sen_deg)
            print("Sen degree:",Sen_deg, file = REs)
            print("cost:", cost)
            print("CPI:", CPI)
            print("CPI:",CPI, file = REs)

            # Keep the candidate with the largest CPI (ties go to the later one).
            if CPI >= CPI_max:
                Sen_deg_opt[:] = Sen_deg
                cost_opt[:] = cost
                CPI_max[:] = CPI
                M_remove = M_remove_try
                M_remain_copy[:] = M_remain_try
                Rank_Sen_remain_max[0,0] = Rank_Sen_remain

            Tend = timeit.default_timer()
            print('Cost time:', str(Tend-Tstart),'s')
            print("-----------------------------------------------------------------------------------")
            print("-----------------------------------------------------------------------------------", file = REs)

        # Even the best candidate leaves a rank-deficient set: stop removing.
        if Rank_Sen_remain_max < Nx:
            print("stop trying")
            print("stop trying", file = REs)
            break

        Rank_Sen_final[0,0] = Rank_Sen_remain_max[0,0]
        Sen_deg_final[0,0] = Sen_deg_opt[0,0]
        cost_final[0,0] = cost_opt[0,0]
        CPI_final[0,0] = CPI_max[0,0]

        print("new removed:",M_remove)
        M_remain = M_remain_copy
        R_known.append(M_remove)
        # End when the cost condition is met; otherwise continue until the
        # rank condition fails.
        if cost_opt <= COST_BUDGET:
            break
        print("removed measurements:\n", R_known)
        print("removed measurements:\n", R_known, file = REs)
        print("remains measurements:\n", M_remain)
        print("remains measurements:\n", M_remain, file = REs)

    print("All removed measurements:\n", R_known)
    print("All removed measurements:\n", R_known, file = REs)
    print("Finally remains:",M_remain)
    print("Rank:",Rank_Sen_final)
    print("Rank:",Rank_Sen_final, file = REs)
    print("Degree:",Sen_deg_final)
    print("Degree:",Sen_deg_final, file = REs)
    print("cost:",cost_final)
    print("cost:",cost_final, file = REs)
    print("CPI:",CPI_final)
    print("CPI:",CPI_final, file = REs)